{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import gym\n",
    "from torch.optim import Adam\n",
    "import random\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Small actor-critic network: a shared trunk with two output heads.\n",
    "\n",
    "    Two ReLU-activated linear layers form the trunk. `pol_layer` maps the\n",
    "    trunk features to action probabilities (softmax over the last dimension)\n",
    "    and `val_layer` maps them to a single state-value estimate.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim, hidden_dim):\n",
    "        \"\"\"Build the four linear layers.\n",
    "\n",
    "        Args:\n",
    "            input_dim: size of the input (observation) vector.\n",
    "            output_dim: number of entries in the policy output.\n",
    "            hidden_dim: width of both trunk layers.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # Creation order is kept on purpose: every nn.Linear draws its initial\n",
    "        # weights when constructed, so reordering changes the seeded start point.\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)          # input -> hidden\n",
    "        self.fc2 = nn.Linear(hidden_dim, hidden_dim)         # hidden -> hidden\n",
    "        self.pol_layer = nn.Linear(hidden_dim, output_dim)   # policy head\n",
    "        self.val_layer = nn.Linear(hidden_dim, 1)            # state-value head\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the network on `x`.\n",
    "\n",
    "        Returns:\n",
    "            A `(policy, value)` tuple: the softmax of the policy head and the\n",
    "            raw output of the value head, both computed from the same features.\n",
    "        \"\"\"\n",
    "        first_hidden = F.relu(self.fc1(x))\n",
    "        features = F.relu(self.fc2(first_hidden))\n",
    "        action_probs = F.softmax(self.pol_layer(features), dim=-1)\n",
    "        state_value = self.val_layer(features)\n",
    "        return action_probs, state_value\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hard_update(target, source):\n",
    "    \"\"\"Overwrite the parameters of `target` with those of `source`, in place.\n",
    "\n",
    "    Parameters are paired in the order returned by `.parameters()`, so both\n",
    "    modules are expected to have the same architecture. Values are copied,\n",
    "    not shared: later changes to `source` do not affect `target`.\n",
    "    \"\"\"\n",
    "    paired_params = zip(target.parameters(), source.parameters())\n",
    "    for dst_param, src_param in paired_params:\n",
    "        dst_param.data.copy_(src_param.data)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility: leave the seed and the hyperparameters below unchanged.\n",
    "# Training is fragile because this minimal implementation has no experience\n",
    "# replay, no n-step returns and no batched updates.\n",
    "seed = 1\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "# `random` is imported in the first cell; seed it as well so that any use of\n",
    "# it stays reproducible. It does not affect the torch or numpy generators.\n",
    "random.seed(seed)\n",
    "\n",
    "num_actions = 2\n",
    "num_obs = 4\n",
    "hidden_dim = 64\n",
    "\n",
    "# Build the online network first and the target network second: each one draws\n",
    "# its initial weights from the seeded generator when it is constructed.\n",
    "network = MLP(num_obs, num_actions, hidden_dim)\n",
    "target_network = MLP(num_obs, num_actions, hidden_dim)\n",
    "# Start the target network as an exact copy of the online network.\n",
    "hard_update(target_network, network)\n",
    "# Only the online network is optimized.\n",
    "optimizer = Adam(network.parameters(), lr=0.0005)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('CartPole-v1')\n",
    "env.seed(seed)\n",
    "\n",
    "# Read the number of actions from the environment instead of hard-coding it.\n",
    "act_num = env.action_space.n\n",
    "\n",
    "# Fail early if the network sizes chosen earlier do not fit this environment.\n",
    "assert act_num == num_actions, 'policy head size does not match the action space'\n",
    "assert env.observation_space.shape == (num_obs,), 'network input size does not match the observation space'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode: 1 Reward: 13.0\n",
      "Episode: 51 Reward: 27.0\n",
      "Episode: 101 Reward: 30.0\n",
      "Episode: 151 Reward: 17.0\n",
      "Episode: 201 Reward: 57.0\n",
      "Episode: 251 Reward: 22.0\n",
      "Episode: 301 Reward: 36.0\n",
      "Episode: 351 Reward: 222.0\n"
     ]
    }
   ],
   "source": [
    "# Settings of the training loop. The values are the ones the notebook was\n",
    "# tuned with; they are only given names here.\n",
    "num_episodes = 351\n",
    "gamma = 0.99                # discount factor of the TD target\n",
    "entropy_coef = 0.1          # weight of the entropy bonus in the loss\n",
    "target_update_every = 1000  # environment steps between target-network syncs\n",
    "log_every = 50              # episodes between progress prints\n",
    "avg_every = 10              # episodes between points of the averaged curve\n",
    "\n",
    "t = 0                 # total number of environment steps taken so far\n",
    "rewards = []          # episode returns collected since the last averaged point\n",
    "average_rewards = []  # one averaged return every `avg_every` episodes\n",
    "for ep_i in range(num_episodes):\n",
    "    done = False\n",
    "    obs = env.reset()\n",
    "    # transform the numpy array to a torch tensor\n",
    "    obs = torch.Tensor(obs)\n",
    "\n",
    "    total_reward = 0\n",
    "    while not done:\n",
    "        # compute the policy for the current observation\n",
    "        policy, _ = network(obs)\n",
    "\n",
    "        # sample an action from the categorical distribution given by the policy\n",
    "        a = np.random.choice([0,1], p=policy.detach().numpy())\n",
    "\n",
    "        # to render the environment, uncomment the line below\n",
    "        # env.render()\n",
    "\n",
    "        # make a step forward in the environment\n",
    "        next_obs, reward, done, _ = env.step(a)\n",
    "\n",
    "        # transform the numpy array to a torch tensor\n",
    "        next_obs = torch.Tensor(next_obs)\n",
    "        total_reward += reward\n",
    "\n",
    "        t += 1\n",
    "        ########################################\n",
    "        #        BEGINNING OF TRAINING         #\n",
    "        ########################################\n",
    "        # reset the gradients accumulated by the previous step\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # TD target: r + gamma * V_target(s'), or just r at a terminal state.\n",
    "        # It is a fixed regression target, so it is computed without building\n",
    "        # a gradient graph; the values are the same as with a later detach().\n",
    "        with torch.no_grad():\n",
    "            if not done:\n",
    "                target_value = torch.Tensor([reward]) + gamma * target_network(next_obs)[1]\n",
    "            else:\n",
    "                target_value = torch.Tensor([reward])\n",
    "\n",
    "        # value of the current observation (second forward pass kept as in the\n",
    "        # original so the seeded run is not perturbed)\n",
    "        value = network(obs)[1]\n",
    "\n",
    "        # one-step advantage used to weight the policy gradient\n",
    "        advantage = target_value - value\n",
    "\n",
    "        # squared TD error; gradients flow only through `value`\n",
    "        td_loss = 0.5 * (value - target_value)**2\n",
    "\n",
    "        # policy-gradient loss L = -advantage * log(pi(a|s)) for the action that\n",
    "        # was actually executed; the advantage is treated as a constant\n",
    "        pg_loss = -advantage.detach() * torch.log(policy[a])\n",
    "\n",
    "        # entropy of the current policy: -sum(policy * log(policy))\n",
    "        entropy = -torch.sum(policy * torch.log(policy))\n",
    "\n",
    "        # add the two losses and subtract the weighted entropy bonus\n",
    "        loss = td_loss + pg_loss - entropy_coef * entropy\n",
    "\n",
    "        # compute the gradients using backprop\n",
    "        loss.backward()\n",
    "\n",
    "        # make an optimization step\n",
    "        optimizer.step()\n",
    "        ########################################\n",
    "        #        END OF TRAINING               #\n",
    "        ########################################\n",
    "\n",
    "        # the next observation becomes the current one\n",
    "        obs = next_obs\n",
    "\n",
    "        # periodically copy the online parameters into the target network\n",
    "        if t % target_update_every == 0:\n",
    "            hard_update(target_network, network)\n",
    "\n",
    "    # The episode is over (the while loop only exits once `done` is True).\n",
    "    if ep_i % log_every == 0:\n",
    "        print('Episode:', ep_i + 1, 'Reward:', total_reward)\n",
    "    rewards.append(total_reward)\n",
    "    if ep_i % avg_every == 0:\n",
    "        # Average over the episodes actually collected: the very first point\n",
    "        # (ep_i == 0) contains a single episode, not `avg_every` of them.\n",
    "        average_rewards.append(sum(rewards) / len(rewards))\n",
    "        rewards = []\n",
    "            \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAD8CAYAAABthzNFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd8XOWd7/HPT8WSLavaai6ysC1bGAy2EWASio0JAVIgl5CQttyEXTZtl5RNFi6b3OwmpOym3GWXFPKCDYEESAgESCBUB5tm3Bsuko0tWc2SrJFkq4+e+8ccmcFI8tjWzJmRvu/Xa15z5swZzVfH8vzmPOc5z2POOURERI6V5HcAERGJTyoQIiIyJBUIEREZkgqEiIgMSQVCRESGpAIhIiJDUoEQEZEhqUCIiMiQVCBERGRIKX4HOBVTp051paWlfscQEUko69evb3bO5R9vu4QuEKWlpaxbt87vGCIiCcXM9keynZqYRERkSCoQIiIyJBUIEREZkgqEiIgMSQVCRESGpAIhIiJDUoEQEZEhqUCIiMSRHfXtvLKn2e8YgAqEiEhc+eHTu/jc/RvoDw74HUUFQkQkntQGumjr6mNTTcDvKCoQIiLxpC7QBcDKXQd9ThLFAmFmM81spZm9YWbbzexmb32emT1rZpXefa633szsDjOrMrMtZrYkWtlEROJRR3cf7d39ALyws8nnNNE9gugHvuqcWwAsBb5gZguAW4DnnXNlwPPeY4ArgTLvdhPwsyhmExGJO/Vt3QCcMS2LHfXtNHiP/RK1AuGcq3fObfCWO4AdwHTgauBeb7N7gWu85auBX7uQ14AcMyuOVj4RkXhT6zUvfXLpLAD+6nMzU0zOQZhZKbAYWAMUOufqvacagEJveTpQE/ayA946EZFxYfD8wyXz8pmeM5EXdo7xAmFmk4E/AF9yzrWHP+ecc4A7wZ93k5mtM7N1TU3+t9GJiIyW+kA3yUlGQWYay8vzebmqmZ7+oG95ologzCyVUHH4jXPuEW9142DTkXc/WCJrgZlhL5/hrXsb59xdzrkK51xFfv5xJ0QSEUkYdYEuirLSSUlOYvn8Ao70Bln7ZqtveaLZi8mAu4Edzrkfhz31OHCDt3wD8FjY+r/xejMtBdrCmqJERMa82kAX03LSAbhgzhQmpCT52t01mkcQ7wY+BVxqZpu821XA94H3mFklcJn3GOBJYC9QBfwS+HwUs4mIxJ26ti6m5UwEYNKEFC6YPYWVPp6HiNqc1M65lwAb5ukVQ2zvgC9EK4+ISDwbGHA0tHVTnD3x6Lrl8/P51hNvsK/5CKVTM2KeSVdSi4jEgebDPfQFHdO9JiaA5eUFgH/dXVUgRETiwOA1EINNTACzpmQwOz+DF3b502NTBUJEJA7UBUJXTYcXCIDl8wt4bW8Lnb39Mc+kAiEiEgfqhjiCALi0vIDe/gFeqWqJeSYVCBGROFDX1kXGhGSy0t/ed6iiNJeMCcm+dHdVgRARiQN1gVAX19AlZG9JS0nm3XOnsnLnQUKdPWNHBUJEJA7UBbrf0bw06NLyAurautndeDimmVQgRETiwOARxFCWzQ91d411M5MKhIiIz7r7grQc6X3bNRDhirLTWVCcFfPRXVUgRER8NjhRUPhV1MdaXp7P+v2ttHX1xSqWCoSIiN+G6+Iabvn8AoIDjtWVsbtoTgVCRMRng1dRTx+hQCwuySVnUiorYzhXtQqEiIjP6gJdmEFhdtqw2yQnGReX5fPi7oMMDMSmu6sKhIiIz+oD3eRPTiMtJXnE7ZaX59N8uJettW0xyaUCISLis7q2LopHaF4adMm8Asxi191VBUJExGe1ga5hu7iGy8uYwKKZOTGbREgFQkTER8650EVyI3RxDXfp/AI2H2ijqaMnyslUIEREfBXo7KO7b2DELq7hBicRenF39HszqUCIiPjorYmCjt/EBLCgOIvpOROp
914XTVGbk1pERI4vkovkwiUlGS9+bRkpydH/fq8jCBERH51ogQBiUhxABUJExFd1bd1MSEliSsYEv6O8gwqEiIiP6gJdTB9ioqB4oAIhIuKjukAXxdmRnaCONRUIEREfjTSTnN9UIEREfNIXHKCxQwVCRESO0djejXNENMyGH1QgRER8UhcIzSSnIwgREXmbwWsgRppq1E8qECIiPjnRYTZiTQVCRMQndYEucielMmlCfI56pAIhIuKT+rb47cEEKhAiIr6pC3SpQIiIyDvVBrqYFqdXUYMKhIiILzq6++jo7tcRhIiIvF19W3xfAwEqECIivqg9iXkgYk0FQkTEB4MXyU1XgRARkXB1gS5Skoz8zDS/owwragXCzO4xs4Nmti1s3bfMrNbMNnm3q8Keu9XMqsxsl5m9N1q5RETiQV2gm8KsdJKT4m+ioEHRPIL4FXDFEOt/4pxb5N2eBDCzBcD1wBnea35qZslRzCYi4qvBmeTiWdQKhHNuFXAows2vBh50zvU4594EqoDzopVNRMRvdW1dcTsG0yA/zkF80cy2eE1Qud666UBN2DYHvHUiImNOcMDREOfDbEDsC8TPgDnAIqAe+NGJ/gAzu8nM1pnZuqamptHOJyISdc2He+gLOopVIN7inGt0zgWdcwPAL3mrGakWmBm26Qxv3VA/4y7nXIVzriI/Pz+6gUVEouCtLq5qYjrKzIrDHn4IGOzh9DhwvZmlmdlpQBnweiyziYjESrzPJDcoaoOQm9kDwDJgqpkdAP4vsMzMFgEO2Af8PYBzbruZ/Q54A+gHvuCcC0Yrm4iIn+oS4CpqiGKBcM59bIjVd4+w/e3A7dHKIyISL2oDXUxOSyErPdXvKCPSldQiIjEWmgcivs8/gAqEiEjMxftMcoNUIEREYizeZ5IbpAIhIhJD3X1BWo70xv0wG6ACISISU4M9mIrjeKrRQSoQIiIxlAgzyQ1SgRARiaHaBJgoaJAKhIhIDNUFujCDwiw1MYmISJi6QBcFmWlMSIn/j9/4TygiMobUBbopzo7/5iVQgRARiam6tvifSW6QCoSISIw45xJmmA1QgRARiZnWzj66+wYSoosrqECIiMRMogzzPUgFQkQkRgavgZimk9QiIhJuY3WAlCRjdn6G31EiogIhIhIjL1U1sWRWLhlpUZurbVSpQIiIxEDL4R621bZzcdlUv6NETAVCRCQGXt7TAsCFZfk+J4mcCoSISAys3t1E9sRUFk7P9jtKxCIqEGZ2s5llWcjdZrbBzC6PdjgRkbHAOcdLVc28e+4UkpPM7zgRi/QI4jPOuXbgciAX+BTw/ailEhEZQ/Y0Haa+rZsL5yZO8xJEXiAGS95VwH3Oue1h60REZASrK5sBuCiBTlBD5AVivZk9Q6hAPG1mmcBA9GKJiIwdqyubKZ0yiZl5k/yOckIi7Yx7I7AI2Ouc6zSzKcCnoxdLRGRs6O0f4LW9LVy7ZIbfUU5YpEcQDlgA/KP3OANIjOEIRUR8tKG6lc7eIBcmWPMSRF4gfgpcAHzMe9wB3BmVRCIiY8hLlc0kJxkXzJnid5QTFmkT0/nOuSVmthHAOddqZhOimEtEZExYXdnEopk5ZKWn+h3lhEV6BNFnZsmEmpows3x0klpEZESBzl621LYlXO+lQZEWiDuAR4ECM7sdeAn4btRSiYiMAS9XteBc4nVvHRRRE5Nz7jdmth5YQej6h2ucczuimkxEJMG9VNVEZloKZ8/I8TvKSRmxQJhZXtjDg8AD4c855w5FK5iISCJzzrFqdzMXzJlCSnJiDnt3vCOI9YTOOxhQArR6yzlANXBaVNOJiCSofS2d1Aa6+OyyOX5HOWkjljXn3GnOudnAc8AHnHNTnXNTgPcDz8QioIhIIlpd2QTARXMT8/wDRH6Seqlz7snBB865p4B3RSeSiEjiW13ZzMy8icyakljDa4SLtEDUmdm/mFmpd7sNqItmMBGRRNUXHODVPS1cODcfs8Qd1zTS
AvExIJ9QV9dHgQLeuqpaRETCbK4JcLinP6GmFx1KpN1cDwE3e6O4Oufc4ejGEhFJXKsqm0kyeNecxC4Qkc4ot9AbZmMbsN3M1pvZmdGNJiKSmF6qbOKsGTlkT0q84TXCRdrE9AvgK865Wc65WcBXgbtGeoGZ3WNmB81sW9i6PDN71swqvftcb72Z2R1mVmVmW8xsycn+QiIifmrr6mNTTSBhr54OF2mByHDOrRx84Jz7K6Ehv0fyK+CKY9bdAjzvnCsDnvceA1wJlHm3m4CfRZhLRCSuvLqnhQEHF5Ul1vSiQ4m0QOw1s2+E9WL6F2DvSC9wzq0Cjr3S+mrgXm/5XuCasPW/diGvATlmVhxhNhGRuLG6somMCcksLknM4TXCRVogPkOoF9Mj3m2qt+5EFTrn6r3lBqDQW54O1IRtd8BbJyKSUF6qCg2vkZqgw2uEi7QXUyvebHLesN8Zzrn2U3lj55wzM3eirzOzmwg1Q1FSUnIqEURERlV1Syf7Wzr59LtK/Y4yKiLtxfRbM8syswxgK/CGmX3tJN6vcbDpyLs/6K2vBWaGbTfDW/cOzrm7nHMVzrmK/PzEb+MTkbFjdZU3vMa8sfHZFOkx0ALviOEa4ClCg/R96iTe73HgBm/5BuCxsPV/4/VmWgq0hTVFiYgkhNW7m5mWnc7sqcfrw5MYIi0QqWaWSqhAPO6c68ObXW44ZvYA8Cow38wOmNmNwPeB95hZJXCZ9xjgSUInvauAXwKfP+HfRETER/3BAV7Z08xFZYk9vEa4SOek/gWwD9gMrDKzWcCI5yCcc8MNxbFiiG0d8IUIs4iIxJ1X9rTQ3t3P8vKx0bwEkZ+kvoPQtKOD9pvZ8uhEEhFJPE9sriMzLYVl8wv8jjJqjjej3Cedc/eb2VeG2eTHUcgkIpJQevqD/GV7A5efUUR6arLfcUbN8Y4gBs+0ZEY7iIhIonpxVxMd3f184OyxdX3viAXCOfcL7/5fYxNHRCTxPL65jryMCbw7gWePG0qk10HMNrMnzKzJG4DvMTObHe1wIiLx7khPP8/taOSqhUVj4urpcJH+Nr8FfgcUA9OA3wMPRCuUiEiieG5HI919A3zgrGl+Rxl1kRaISc65+5xz/d7tfiA9msFERBLBE5vrKM5O59zSPL+jjLpIC8RTZnaLN5LrLDP7OvCkN7/D2NsrIiIRCHT28uLuJt5/VjFJSWPj4rhwkV4o9xHv/u+PWX89oSuqdT5CRMadp7c30Bd0fODssde8BJFfKHdatIOIiCSaxzfXUTplEgunZ/sdJSpGbGLympIGl6875rnvRiuUiEi8O9jRzat7Wvjg2dPGzNhLxzreOYjrw5ZvPea5Y6cTFREZN57cUs+AY8w2L8HxC4QNszzUYxGRcePxzXWUF2VSVjh2B5o4XoFwwywP9VhEZFyoOdTJhuoAH1w0do8e4Pgnqc82s3ZCRwsTvWW8x7oOQkTGpT9tCc1nNhYvjgt3vLGYxs6whCIio+TxzXUsLslhZt4kv6NE1dgaOEREJMqqDnawo76dD47hk9ODVCBERE7A45vrSTJ438KxNbT3UFQgREQi5Jzjic11LJ09hYKssX8aVgVCRCRC22rbebP5yLhoXgIVCBGRiD2xpY7UZOOKM4v8jhITKhAiIhEYGAg1L11clk/OpAl+x4kJFQgRkQis299KfVv3mL84LpwKhIhIBP64qZb01CQuO73Q7ygxowIhInIcm2oCPLS2hqvPnk5GWqTT6CQ+FQgRkREc6ennSw9upCgrnf/zvtP9jhNT46cUioichO/8+Q32H+rkgb9bSvbEVL/jxJSOIEREhvH09gYeeL2Gz14yh6Wzp/gdJ+ZUIEREhnCwvZtb/rCFM6dn8eXL5vkdxxcqECIixxgYcPzTw1vo6gvy/z66mAkp4/Ojcnz+1iIiI7j31X2s2t3Ebe9bwNyCyX7H8Y0KhIhImF0NHXzvqZ2sKC/gk+eX+B3HVyoQIiKenv4gNz+4
kaz0FH7w4bMwM78j+UrdXEVEPD98ehc7Gzq4+4YKpk5O8zuO73QEISICvFzVzC9Xv8knl5awYhwNpzESFQgRGfdaj/Ty1d9tZk5+BrddtcDvOHFDBUJExrXNNQGu+enLtBzp4T+vX8zECcl+R4obKhAiMi4NDDjuWrWHa3/2Cv1BxwN/t5Qzp2f7HSuu6CS1iIw7TR09fPX3m1m1u4krzijiB9eeRfak8TXOUiR8KRBmtg/oAIJAv3OuwszygIeAUmAf8BHnXKsf+URk7Fpd2cSXH9pMR3cft3/oTD5+Xsm47846HD+bmJY75xY55yq8x7cAzzvnyoDnvcciIqOiLzjA957awafufp3cSak8/sUL+cT5s1QcRhBPTUxXA8u85XuBvwL/7FcYERk7qls6+YcHN7K5JsDHzy/hG+9boJPREfCrQDjgGTNzwC+cc3cBhc65eu/5BkAdkUXklDW2d/P+/1oNwE8/sYSrFhb7nChx+FUgLnTO1ZpZAfCsme0Mf9I557zi8Q5mdhNwE0BJyfgeJ0VEju+B16vp6OnnmS9dTFlhpt9xEoov5yCcc7Xe/UHgUeA8oNHMigG8+4PDvPYu51yFc64iPz8/VpFFJAH1Bwd4aG0NF5XlqzichJgXCDPLMLPMwWXgcmAb8Dhwg7fZDcBjsc4mImPLX3c1Ud/WzcfPU2vDyfCjiakQeNTrOZAC/NY59xczWwv8zsxuBPYDH/Ehm4iMIb9Zs5/CrDRWnF7gd5SEFPMC4ZzbC5w9xPoWYEWs84jI2HSgtZO/7m7iH5bPJTVZg0acDO01ERmTHlpbgwEfVfPSSVOBEJExpy84wINra1g2v4DpORP9jpOwVCBEZMx5fkcjTR09fGKcTxl6qlQgRGTM+c2aaqZlp7Nsvk5OnwoVCBEZU/a3HGF1ZTMfPbeE5CSNs3QqVCBEZEx54PUakpOMj5470+8oCU8FQkTGjN7+AX6/roYV5QUUZaf7HSfhqUCIyJjx9PYGWo708nGdnB4VKhAiMmb8dk01M3IncnGZxmkbDfE0H4SIjBHNh3v40TO7SEtJ5sPnzOCMaVlRn5hnT9NhXt3bwtfeO58knZweFSoQIjKqXqps5su/20RbVx84+NUr+ygvyuS6iplcvWgaUyenReV9H1hTTUqScV3FjKj8/PFIBUJERkVfcIAfP7ubn7+4hzn5k7nvxvMoykrniS31PLz+AN/+0xt878kdLC8v4LpzZrC8vGDUxkjq7gvy8IYDXH5GIQWZOjk9WlQgROSU1Rzq5B8f3MjG6gAfO28m33z/GUen9PzU0ll8auksdjd28If1B3hkYy3PvtHIlIwJXLN4On970WkUZ5/acBh/2dZAoLOPT5w/azR+HfGYc0NO3JYQKioq3Lp16/yOITKu/WlLHbf+YSsYfO9/LeT9Z00bcfv+4ACrKpt4eP0Bnn2jkSQzPnPhaXz2kjlkT0w9qQzX/fwVmjp6eOGry3T+IQJmtt45V3G87XQEISInpbO3n3974g0eXFvD4pIc7rh+MTPzJh33dSnJSVxaXsil5YXUHOo82iz1wOvVfHH5XD51wSzSUpIjzrG7sYO1+1q59cpyFYdRpm6uInLCdtS388H/fpmH1tXw+WVz+N3fXxBRcTjWzLxJ/OSji/jTP1zIwunZfOfPO7j0hy/y6MYDDAyM3LrRHxxgV0MHdzxfyYTkJD58jk5OjzYdQYgkiPq2Ll6uauEDZxef0Dfs0RQccPxi1R5+8uxuciZN4P4bz+fdc6ee8s89Y1o29914Pi9VNvO9p3bw5Yc2c9eqN7nlynIuLpvKkd4gO+vbeaO+nTfq2tle186uxg56+wcA+OTSEqZEqXfUeKZzECJxzjnHw+sP8G9PvEFHTz+zp2bwnWvO5F2j8MF8IvY1H+Grv9/M+v2tXLWwiO9cs5C8jAmj/j4DA44nttTxw2d2UXOoi8KsNBrbe44+nzsplTOmZbNgWhYLirNYMC2LufmT1bx0
AiI9B6ECIXKCnHMEBxxB5xgYgKBzpCZbVL7VH2zv5tZHtvL8zoOcV5rHx88v4SfP7WZ/SycfWjyd2953etSuKxjknOP+NdV89887SE02vn3NmXzw7GlRv/Ctpz/Ib9dUs6E6wLyCyaGCMC2Loqz0qL/3WKcCEQeCA47tdW2UFWQe7fInieWFnY388x+20tbZFyoIzjHUf5kkg9IpGcwvyqS8KMu7z6Qkb9JJfbN1zvH45jq++dh2uvuCfP2Kcj79rlKSkozuviB3rqzi5y/uYdKEFG65spyPVsyMyjfo+rYuvv7wFlZXNnPxvHz+/dqzNAjeGKAC4bPNNQG+8dg2thxoY2JqMpeWF3DVwmKWl+czaULsTv0456gNdLFuXyu1gS7m5E8+pQ+ueNHW2cfzOxspzp5IeVEmuVFo6njg9Wpue3Qr5UVZXDwvn+QkSDYjKclIMiP56D0c7gmyu6GDnQ3t7D/UebSITJqQTFlhJuWFmZw5PYvFJbmUF2WSMsIFYs2He7jt0a08vb2RxSU5/Oi6s5mdP/kd21UdPMxtj25lzZuHOGdWLrd/6EzKi7JG5Xd3zvHYpjq++dg2+oKO2953Op84v0Tf3McIFQiftHX28R/P7OQ3a6qZOjmNLy6fS9XBwzy1rYHmwz2kpya9VSzmF5CRNrrFoj84wM6GDtbtO8Ta/a2s39dKQ3v3O7YL/+AqL848+s03Gm3Ko6mnP8h9r+7nv16oCg3l4CnMSmN+URbl3jf3+UWZzC2YfFLNPs45fvLsbu54oYpl8/O58+NLTujfqbO3n92Nh9nV0M7Ohg521ocKR2tnKO/E1GTOmpHN4pJclpTksLgkl/zMUDPRk1vr+Zc/buNwdz9fuXwef3fR7BEnvXHO8YcNtXz3yR20d/Vx40WncfOKspP+EjIw4Nh8IMAvV+/lya0NnDMrlx9ddzalUzNO6udJfFKBiDHnHI94/1FbO3u54V2lfPk988hKD134ExxwrN13iD9vqX9bsVg2r4DLzyikYlYeM/MmnvA3tCM9/WysDrBu/yHW7WtlY3UrR3qDABRnp1NRmse5pbmcMyuXkrxJ7Gk6MuwHF8DZM7K55crTuWDOlNHbOaPAOceft9bz73/ZRfWhTi4qm8rNK8o40ht82+9TdfAwvcFQz5bkJGPRzBxuXlHGRWVTI9q3fcEBbn1kKw+vP8BHKmZw+4cWjspwEM45DrR2saG6lY3VATZWt7K9rp1+ryvnzLyJFGWls3ZfK2fNyOZH151NWWFmxD+/9Ugv339qJw+tq2FiajIVpbksnT2Fd82ZwsLp2SMesRzu6eelyiae33GQlbsO0ny4lwnJSREVKElMKhCn6Nt/eoPfr6thcUkuFbNyqSjNY9HMnCHPJexq6OAbf9zG6/sOsaQkh29fcyZnTMse9mcPFosnt4aKRVNHqIfG1MkTWFySy+KSHJaU5HLWjOx3fBNsbO9m7b5QMVi3/xA76jsIDjjMoLwoy8sayjs95/jDFzjnaOroYWdDB9vr2rnv1X3UtXVz2ekF3HJlOXMLIv+QikTNoU5WVTZxpKefRTNzWTg9+7jnZ9btO8TtT+5gY3WA8qJMbr3qdC6ZN/Rwzv3BAfa1HGGHV/z+uLGO2kAXF8yewtevmM/iktxh3+dwTz+fu389qyub+dJlZdy8oiyqTSrdfUG21bYdLRq7Gzu4ZtF0PrtszkkXpfX7W3licx2v7mlhV2MHAJPTUji3NJcL5kzhgtlTWTAti7pAFy/sPMhzOxpZs/cQvcEBstJTWDa/gBWnF3DJvHxyJsX30aScPBWIU1AX6OKS/1jJ/KJM+vrd0f9oKUnGGdOzqZiVy7mluaG+26/t5+6X3iQzPYVbryznunNO7GRhcMCxq6Hjbd8s9zYfAULfgMuLMllcksORniDr9h+i5lAXEGqmWDQzJ3R0UJrH4pKco0crp6K7L8g9L7/Jz1buobMvyPXnzuRLl8072gRyoo709PPqnhZW
VzaxqrKZN73fbVBKknF6cdbRppYlJblHj6TebD7CD57ayV+2N1CYlcZX3zOfa8+ZcULfaAd7wvz3C1W0HOnlvWcU8rX3zn9H4TvY3s2nf7WWnQ0dfO9DC/nIGJiusvlwD6/tbeHVPS28ureFvU2hfT8xNZmuvtBR5uz8DFaUF7Di9ELOmZU7aoPnSXxTgTgF//rEdu57dT9//doyZuROItDZy4bq1tC39n2tbDoQOHqBDsD1587k61eUj1r7feuRXjbVBNhQ3cqG6lY2VQeYlJZy9EimYlYuC6ZlRfU/c8vhHu54vpLfrKkmLSWJzy2bw40Xzj7ut/2u3iB7mg6zqrKJVbubWL+/lb6gY2JqMhfMmcJFZVO5eF4+ORNT2VgdOFoYNx8I0Ok1jU2dPIH5RZms2XuICSlJfPaSOfztRaed0sn9wz393PPSm9y1ai+dvf1cu2QGX3rPPKbnTKTq4GFuuOd1Wjt7ufMTS1g+v+Ck3yeeNbZ389reFtbvb6UkbxIrTi/kNJ1bGJdUIE5S8+EeLvzBC7z/rGn88Lqzh9ympz/Ittp2thwIsGhmzojNFqNh8N/Ijx4ke5sO84O/7OTp7Y0UZaXz5feUMSN3EnWBLhrauqlr66ahrYv6tm7q27rfduL49OIsLp43lUvK8jmnNHfEE8b9wQF2NXYcLRrba9upKM3l5svKRnX45kNHevnpyip+/dp+cHDtOTN4als9KUlJ/M//PpeFM4ZvGhQZK1QgTtJ/PL2Tn/51D8995RLmDNG1cLxau+8Q3/nzDjbXBN62fkrGBIqy0ynOnkhxdjpF2enMzJvE0tl5cT0uf22gi/98bjcPrz9A6ZQM7v3MeSc1lpBIIlKBOAltXX1c+P0XuHhePnd+Ysmo/dyxwjnHa3sPYRbqIVWYlU56amJfAFgb6CJ7YiqTR7m7sUg803DfJ+H+1/bT0dPP55bN8TtKXDKzuOv+eqoi6eklMl6py4KnqzfI3S+9yfL5+Zw5Xe3QIiIqEJ4HXq/m0JFevrB8rt9RRETiggoE0Ns/wF2r9nL+aXlUlOb5HUdEJC6oQACPbDhAQ3u3jh5ERMKM+wLRHxzgZy/uYeH0bC4qi+0ELCIi8WzcF4g/b61nf0snX1g+V0P7UCw2AAAGEklEQVQZi4iEGdcFYmDA8dOVeygrmMzlCwr9jiMiElfGdYF4fudBdjV28PnlcxJ68hwRkWgYtwXCOcd/r6xiZt5EPnDWNL/jiIjEnXFbIF7Z08LmmgCfvWTOiJOpiIiMV3H3yWhmV5jZLjOrMrNbovU+d66soiAzjWuXzIjWW4iIJLS4KhBmlgzcCVwJLAA+ZmYLRvt9NlS38sqeFm66eHbCDzYnIhItcVUggPOAKufcXudcL/AgcPVov4lzcFHZVD52Xslo/2gRkTEj3grEdKAm7PEBb91RZnaTma0zs3VNTU0n9SbnzMrlvhvPJ0NDPIuIDCveCsRxOefucs5VOOcq8vOHnrheREROXbwViFogfLb4Gd46ERGJsXgrEGuBMjM7zcwmANcDj/ucSURkXIqrRnjnXL+ZfRF4GkgG7nHObfc5lojIuBRXBQLAOfck8KTfOURExrt4a2ISEZE4oQIhIiJDUoEQEZEhmXPO7wwnzcyagP0n+fKpQPMoxom2RMqbSFkhsfImUlZIrLyJlBVOLe8s59xxLyRL6AJxKsxsnXOuwu8ckUqkvImUFRIrbyJlhcTKm0hZITZ51cQkIiJDUoEQEZEhjecCcZffAU5QIuVNpKyQWHkTKSskVt5EygoxyDtuz0GIiMjIxvMRhIiIjGBcFohYTWt6ssxsn5ltNbNNZrbOW5dnZs+aWaV3n+tjvnvM7KCZbQtbN2Q+C7nD29dbzGxJHGT9lpnVevt3k5ldFfbcrV7WXWb23lhm9d5/ppmtNLM3zGy7md3srY+7/TtC1rjcv2aWbmavm9lmL++/eutPM7M1Xq6HvIFCMbM073GV93xpHGT9lZm9
GbZvF3nro/N34JwbVzdCgwDuAWYDE4DNwAK/cx2TcR8w9Zh1/w7c4i3fAvzAx3wXA0uAbcfLB1wFPAUYsBRYEwdZvwX80xDbLvD+HtKA07y/k+QY5y0GlnjLmcBuL1fc7d8Rssbl/vX20WRvORVY4+2z3wHXe+t/DnzOW/488HNv+XrgoTjI+ivgw0NsH5W/g/F4BBGTaU2j4GrgXm/5XuAav4I451YBh45ZPVy+q4Ffu5DXgBwzK45N0mGzDudq4EHnXI9z7k2gitDfS8w45+qdcxu85Q5gB6FZFeNu/46QdTi+7l9vHx32HqZ6NwdcCjzsrT923w7u84eBFWZmPmcdTlT+DsZjgTjutKZxwAHPmNl6M7vJW1fonKv3lhuAQn+iDWu4fPG6v7/oHYrfE9ZcF1dZvSaNxYS+Pcb1/j0mK8Tp/jWzZDPbBBwEniV0FBNwzvUPkeloXu/5NmCKX1mdc4P79nZv3/7EzNKOzeoZlX07HgtEIrjQObcEuBL4gpldHP6kCx1Txm33s3jPB/wMmAMsAuqBH/kb553MbDLwB+BLzrn28Ofibf8OkTVu969zLuicW0RotsrzgHKfIw3r2KxmdiZwK6HM5wJ5wD9HM8N4LBBxP62pc67Wuz8IPEroD7lx8JDRuz/oX8IhDZcv7va3c67R+883APySt5o54iKrmaUS+sD9jXPuEW91XO7fobLG+/4FcM4FgJXABYSaYwbnxgnPdDSv93w20BLjqOFZr/Ca9Zxzrgf4H6K8b8djgYjraU3NLMPMMgeXgcuBbYQy3uBtdgPwmD8JhzVcvseBv/F6WSwF2sKaSnxxTNvshwjtXwhlvd7rvXIaUAa8HuNsBtwN7HDO/Tjsqbjbv8Nljdf9a2b5ZpbjLU8E3kPovMlK4MPeZsfu28F9/mHgBe/oza+sO8O+JBihcyXh+3b0/w6ieSY+Xm+EzvjvJtT+eJvfeY7JNptQT4/NwPbBfITaPp8HKoHngDwfMz5AqOmgj1Bb543D5SPUq+JOb19vBSriIOt9XpYt3n+s4rDtb/Oy7gKu9GHfXkio+WgLsMm7XRWP+3eErHG5f4GzgI1erm3AN731swkVqirg90Catz7de1zlPT87DrK+4O3bbcD9vNXTKSp/B7qSWkREhjQem5hERCQCKhAiIjIkFQgRERmSCoSIiAxJBUJERIakAiEiIkNSgRARkSGpQIiIyJD+P3KYuDacWUuoAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One averaged point is recorded every 10 episodes, starting at episode 0,\n",
    "# so the x positions are derived from the number of points actually collected.\n",
    "episodes = np.arange(len(average_rewards)) * 10\n",
    "plt.plot(episodes, average_rewards)\n",
    "plt.xlabel('Episodes')\n",
    "plt.ylabel('Average rewards per 10 episodes')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
